// Copyright (c) 1999-2004 Brian Wellington (bwelling@xbill.org)

package org.xbill.DNS;

import java.io.IOException;
import java.security.SecureRandom;
import java.util.Random;
import lombok.SneakyThrows;

/**
 * A DNS message header: a 16-bit message ID, a 16-bit flags word (which also carries the opcode
 * and rcode), and four 16-bit section record counts.
 *
 * <p>The message ID is generated lazily: a header created without an explicit ID gets a random one
 * the first time {@link #getID()} is called.
 *
 * @see Message
 * @author Brian Wellington
 */
public class Header implements Cloneable {

  /** Largest value representable in an unsigned 16-bit header field. */
  private static final int MAX_U16 = 0xFFFF;

  /** Largest value representable in the 4-bit opcode and rcode fields. */
  private static final int MAX_U4 = 0xF;

  /** Number of section counts stored in the header. */
  private static final int NUM_COUNTS = 4;

  /** Number of bits in the flags word; flag bits are numbered 0 (leftmost) to 15 (rightmost). */
  private static final int FLAG_BITS = 16;

  /** Position of the opcode's least significant bit within the flags word. */
  private static final int OPCODE_SHIFT = 11;

  /** Internal marker meaning "no ID assigned yet"; replaced lazily by {@link #getID()}. */
  private static final int ID_UNASSIGNED = -1;

  private int id;
  private int flags;
  private int[] counts;

  private static final Random random = new SecureRandom();

  /** The length of a DNS Header in wire format. */
  public static final int LENGTH = 12;

  /**
   * Create a new empty header.
   *
   * @param id The message id
   * @throws IllegalArgumentException if the id is not in the range 0..65535
   */
  public Header(int id) {
    this();
    setID(id);
  }

  /**
   * Create a new empty header without an explicit message id. A random id is chosen the first time
   * {@link #getID()} is called (which includes serializing or printing the header).
   */
  public Header() {
    counts = new int[NUM_COUNTS];
    flags = 0;
    id = ID_UNASSIGNED;
  }

  /**
   * Parses a Header from a stream containing DNS wire format.
   *
   * @param in The input to read the id, the flags and the four section counts from, in that order.
   * @throws IOException if reading from the input fails
   */
  Header(DNSInput in) throws IOException {
    this(in.readU16());
    flags = in.readU16();
    for (int i = 0; i < counts.length; i++) {
      counts[i] = in.readU16();
    }
  }

  /**
   * Creates a new Header from its DNS wire format representation
   *
   * @param b A byte array containing the DNS Header.
   * @throws IOException if the header cannot be read from the array
   */
  public Header(byte[] b) throws IOException {
    this(new DNSInput(b));
  }

  /**
   * Writes the header to the output as the id, the flags and the four section counts, each as an
   * unsigned 16-bit value. Assigns a random id first if none has been set yet.
   */
  void toWire(DNSOutput out) {
    out.writeU16(getID());
    out.writeU16(flags);
    for (int count : counts) {
      out.writeU16(count);
    }
  }

  /**
   * Converts the header to its wire format representation.
   *
   * @return A newly allocated byte array containing the serialized header.
   */
  public byte[] toWire() {
    DNSOutput out = new DNSOutput();
    toWire(out);
    return out.toByteArray();
  }

  /** Returns whether {@code bit} is a position in the flags word that {@link Flags} accepts. */
  private static boolean validFlag(int bit) {
    return bit >= 0 && bit < FLAG_BITS && Flags.isFlag(bit);
  }

  /** Throws {@link IllegalArgumentException} unless {@code bit} is a valid flag position. */
  private static void checkFlag(int bit) {
    if (!validFlag(bit)) {
      throw new IllegalArgumentException("invalid flag bit " + bit);
    }
  }

  /** Returns the mask selecting {@code bit} in the flags word; bits are indexed left to right. */
  private static int flagMask(int bit) {
    return 1 << (FLAG_BITS - 1 - bit);
  }

  /**
   * Returns a copy of {@code flags} with the given flag bit set or cleared.
   *
   * @param flags The flags word to start from.
   * @param bit The flag position, indexed from the left.
   * @param value {@code true} to set the bit, {@code false} to clear it.
   * @throws IllegalArgumentException if {@code bit} is not a valid flag
   */
  static int setFlag(int flags, int bit, boolean value) {
    checkFlag(bit);
    int mask = flagMask(bit);
    return value ? (flags | mask) : (flags & ~mask);
  }

  /**
   * Sets (turns on) a flag
   *
   * @throws IllegalArgumentException if {@code bit} is not a valid flag
   * @see Flags
   */
  public void setFlag(int bit) {
    // The static helper validates the bit.
    flags = setFlag(flags, bit, true);
  }

  /**
   * Clears (turns off) a flag
   *
   * @throws IllegalArgumentException if {@code bit} is not a valid flag
   * @see Flags
   */
  public void unsetFlag(int bit) {
    // The static helper validates the bit.
    flags = setFlag(flags, bit, false);
  }

  /**
   * Retrieves a flag
   *
   * @return {@code true} if the flag is set
   * @throws IllegalArgumentException if {@code bit} is not a valid flag
   * @see Flags
   */
  public boolean getFlag(int bit) {
    checkFlag(bit);
    return (flags & flagMask(bit)) != 0;
  }

  /**
   * Returns the state of every flag position as a 16-element array. Positions that are not valid
   * flags are reported as {@code false}.
   */
  boolean[] getFlags() {
    boolean[] array = new boolean[FLAG_BITS];
    for (int i = 0; i < array.length; i++) {
      if (validFlag(i)) {
        array[i] = getFlag(i);
      }
    }
    return array;
  }

  /**
   * Retrieves the message ID. If no ID has been assigned yet, a random one covering the full range
   * 0..65535 is generated, stored and returned.
   */
  public int getID() {
    synchronized (random) {
      if (id < 0) {
        // nextInt's bound is exclusive, so MAX_U16 + 1 is needed to make 0xFFFF reachable.
        id = random.nextInt(MAX_U16 + 1);
      }
      return id;
    }
  }

  /**
   * Sets the message ID
   *
   * @throws IllegalArgumentException if the id is not in the range 0..65535
   */
  public void setID(int id) {
    if (id < 0 || id > MAX_U16) {
      throw new IllegalArgumentException("DNS message ID " + id + " is out of range");
    }
    this.id = id;
  }

  /**
   * Sets the message's rcode, stored in the low four bits of the flags word
   *
   * @throws IllegalArgumentException if the value is not in the range 0..15
   * @see Rcode
   */
  public void setRcode(int value) {
    if (value < 0 || value > MAX_U4) {
      throw new IllegalArgumentException("DNS Rcode " + value + " is out of range");
    }
    flags = (flags & ~MAX_U4) | value;
  }

  /**
   * Retrieves the message's rcode
   *
   * @see Rcode
   */
  public int getRcode() {
    return flags & MAX_U4;
  }

  /**
   * Sets the message's opcode, stored in four bits of the flags word starting at bit 11
   *
   * @throws IllegalArgumentException if the value is not in the range 0..15
   * @see Opcode
   */
  public void setOpcode(int value) {
    if (value < 0 || value > MAX_U4) {
      throw new IllegalArgumentException("DNS Opcode " + value + " is out of range");
    }
    // Clear the four opcode bits (mask 0x87FF within the 16-bit word), then store the new value.
    flags &= MAX_U16 & ~(MAX_U4 << OPCODE_SHIFT);
    flags |= value << OPCODE_SHIFT;
  }

  /**
   * Retrieves the message's opcode
   *
   * @see Opcode
   */
  public int getOpcode() {
    return (flags >> OPCODE_SHIFT) & MAX_U4;
  }

  /**
   * Sets the record count of a section.
   *
   * @throws IllegalArgumentException if the value is not in the range 0..65535
   */
  void setCount(int field, int value) {
    if (value < 0 || value > MAX_U16) {
      throw new IllegalArgumentException("DNS section count " + value + " is out of range");
    }
    counts[field] = value;
  }

  /**
   * Increments the record count of a section.
   *
   * @throws IllegalStateException if the count is already 65535
   */
  void incCount(int field) {
    if (counts[field] == MAX_U16) {
      throw new IllegalStateException("DNS section count cannot be incremented");
    }
    counts[field]++;
  }

  /**
   * Decrements the record count of a section.
   *
   * @throws IllegalStateException if the count is already 0
   */
  void decCount(int field) {
    if (counts[field] == 0) {
      throw new IllegalStateException("DNS section count cannot be decremented");
    }
    counts[field]--;
  }

  /**
   * Retrieves the record count for the given section
   *
   * @see Section
   */
  public int getCount(int field) {
    return counts[field];
  }

  /** Returns the raw flags word, including the opcode and rcode bits. */
  int getFlagsByte() {
    return flags;
  }

  /** Converts the header's flags into a String, each set flag followed by a single space */
  public String printFlags() {
    StringBuilder sb = new StringBuilder();

    for (int i = 0; i < FLAG_BITS; i++) {
      if (validFlag(i) && getFlag(i)) {
        sb.append(Flags.string(i));
        sb.append(" ");
      }
    }
    return sb.toString();
  }

  /**
   * Converts the header into a String, reporting {@code newrcode} as the status instead of the
   * rcode stored in the flags word.
   */
  String toStringWithRcode(int newrcode) {
    StringBuilder sb = new StringBuilder();

    sb.append(";; ->>HEADER<<- ");
    sb.append("opcode: ").append(Opcode.string(getOpcode()));
    sb.append(", status: ").append(Rcode.string(newrcode));
    sb.append(", id: ").append(getID());
    sb.append("\n");

    sb.append(";; flags: ").append(printFlags());
    sb.append("; ");
    for (int i = 0; i < NUM_COUNTS; i++) {
      sb.append(Section.string(i)).append(": ").append(getCount(i)).append(" ");
    }
    return sb.toString();
  }

  /** Converts the header into a String */
  @Override
  public String toString() {
    return toStringWithRcode(getRcode());
  }

  /**
   * Creates a new Header identical to the current one. The section counts are copied, so later
   * changes to one header do not affect the other. If this header has no id yet, the copy has none
   * either and will generate its own on first use.
   */
  @Override
  @SneakyThrows(CloneNotSupportedException.class)
  public Header clone() {
    // super.clone() already copies the primitive id and flags; only the array needs a deep copy.
    Header h = (Header) super.clone();
    h.counts = counts.clone();
    return h;
  }
}
